Switch to unified view

a b/graph_algorithm/RandomWalk.java
1
package GraphAlgorithm;
2
3
import java.io.BufferedWriter;
4
import java.io.File;
5
import java.io.FileWriter;
6
import java.io.IOException;
7
import java.util.ArrayList;
8
import java.util.Collections;
9
import java.util.Comparator;
10
import java.util.HashMap;
11
import java.util.LinkedHashMap;
12
import java.util.LinkedList;
13
import java.util.List;
14
import java.util.Map;
15
import java.util.Set;
16
import network.CommGeneticsLoader;
17
import network.DisGraph;
18
import network.SparseMatrix;
19
import network.SparseVector;
20
import util.DCNOMIMUMLSIDmap;
21
import util.FileToList;
22
23
/**
24
 * Implement random walk with restart algorithm using sparse matrix data structure
25
 * @author zhengc
26
 *
27
 */
28
public class RandomWalk {
29
    
30
//  static final double P = 0.15;
31
//  static final double P = 0.15;
32
    static final double err = 0.000001;
33
    private Map<String, Integer> generank;
34
    
35
    /**
36
     * Constructs a new RandomWalk object from a graph and seed vector
37
     * @param cgg a CommGeneticsGraph 
38
     * @param E the seed vector
39
     * @param p restart rate
40
     */
41
    public RandomWalk(DisGraph cgg, SparseVector E, double p) {
42
        Map<String, Double> genescore = new HashMap<>();
43
        int size = cgg.getNodes();
44
        SparseVector x = new SparseVector(size);
45
        SparseMatrix A = cgg.getNet();
46
        A = normalizeMatrix(A);
47
        int iter = 0;
48
        SparseVector last_x = E;
49
        while((last_x.plus(x.scale(-1)).norm()) > err) {
50
            last_x = x;
51
            x = A.times(x).scale(1-p);
52
            x = x.plus(E.scale(p));
53
//          System.out.print(x);
54
            iter++;
55
//          System.out.println("Finish iteration " + iter + ":");
56
        }
57
        System.out.println("Total iteration " + iter);
58
        
59
        genescore = svToMap(x, E);
60
        generank = scoreToRank(genescore);
61
        
62
    }
63
    
64
    /**
65
     * Get the gene-rank pairs from random walk with restart
66
     * @return the gene-rank pairs from random walk with restart
67
     */
68
    public  Map<String, Integer> getRWRank() {
69
        return generank;
70
    }
71
    
72
    
73
    /**
74
     * Generate column/left stochastic matrix by normalization
75
     * @param comMatrix
76
     * @return normalized SparseMatrix
77
     */
78
    private static SparseMatrix normalizeMatrix(SparseMatrix comMatrix) {
79
        
80
        int n = comMatrix.size();
81
        double[] total_w = new double[n]; //Total weight of a column
82
        SparseMatrix A = new SparseMatrix(n);
83
        
84
        for (int j = 0; j < n; j++){
85
            total_w[j] = 0;
86
            for (int i = 0; i < n; i++){
87
                total_w[j] += comMatrix.get(i, j);
88
            }
89
        }
90
91
        
92
        //normalize the adjacency matrix
93
        for (int i = 0; i < n; i++){
94
            
95
            for (int j = 0; j < n; j++){
96
                double norm_w = comMatrix.get(i, j);
97
                norm_w = norm_w * 1.0 / total_w[j];
98
                A.put(i, j, norm_w);
99
            }
100
        }
101
        
102
        return A;
103
    }
104
    
105
    /**
106
     * Create a seed from a list
107
     * @param nodesids a list of seed nodes
108
     * @return the seed vector
109
     * @throws IOException
110
     */
111
    public static SparseVector createSeedVector(List<String> nodesids) throws IOException{
112
        
113
        int n = CommGeneticsLoader.entry_index.size();
114
        SparseVector E = new SparseVector(n);
115
116
        for (String id : nodesids) {
117
            if (CommGeneticsLoader.entry_index.containsKey(id)) {
118
                int index = CommGeneticsLoader.entry_index.get(id);
119
//              E.put(index, 1);
120
                E.put(index, 1.0 / nodesids.size());
121
            } else {
122
                System.out.println(id + " is NOT included in our network!");
123
            }
124
        }
125
        return E;
126
    }
127
    
128
    
129
    public static SparseVector createSeedVectorFromIdx(List<Integer> nodeidxs) {
130
        int n = CommGeneticsLoader.entry_index.size();
131
        SparseVector E = new SparseVector(n);
132
133
        for (int idx : nodeidxs) {
134
            if (idx > n - 1) {
135
                System.out.println(idx + " is out of network!");
136
            } else {
137
                E.put(idx, 1.0 / nodeidxs.size());;
138
            }
139
        }
140
        return E;
141
    }
142
    
143
    /**
144
     * Convert a SparseVector from random walk into a HashMap reversely ordered by value
145
     * @param sv the SparseVector representing gene rank score result
146
     * @param seeds_vector the seed vector
147
     * @return a HashMap reversely ordered by value 
148
     */
149
    private static Map<String, Double> svToMap(SparseVector sv, SparseVector seeds_vector) {
150
        Map<String, Double> genemap = new HashMap<String,Double>();
151
//      System.out.println(CommGeneticsLoader.num_disease);
152
//      System.out.println(CommGeneticsLoader.entry_list.size());
153
        for (int i = CommGeneticsLoader.num_disease; i < CommGeneticsLoader.entry_list.size(); i++){
154
//          double val_seed = seeds_vector.get(i);
155
            Set<Integer> seeds_key = seeds_vector.getKeys();
156
            if (!seeds_key.contains(i)) { // Test if the node is one of seeds
157
                double val = sv.get(i);
158
//              System.out.println(index_entry.get(num_disease));
159
                
160
                if (val != 0){
161
                    String gene = CommGeneticsLoader.index_entry.get(i);
162
                    genemap.put(gene, val);
163
                }
164
            }
165
        }
166
        genemap = util.CollectionsEx.sortByValueR(genemap);
167
        return genemap;
168
    }
169
    
170
    /**
171
     * Convert a gene rank score into a gene rank HashMap reversely ordered by value
172
     * @param sortedgenemap a reversely sorted gene score map
173
     * @return a gene rank  HashMap
174
     */
175
    private static Map<String, Integer> scoreToRank(Map<String, Double> sortedgenemap) {
176
        Map<String, Integer> generank = new HashMap<String,Integer>();
177
        int rank = 0;
178
        for (String s : sortedgenemap.keySet()) {
179
//          System.out.println(sortedgenemap.get(s));
180
            rank++;
181
            generank.put(s, rank);
182
            generank = util.CollectionsEx.sortByValue(generank);
183
        }
184
        
185
        return generank;
186
    }
187
    
188
//  System.out.println(genemap.size());
189
    
190
    
191
    
192
    /**
193
     * Write gene rank scores into CSV file, file format is "Rank, Gene, Score"
194
     * @param sortedgenemap HashMap sorted by descend values
195
     * @param rankfile file to be written
196
     * @throws IOException
197
     */
198
    public static void saveScorefile(Map<String, Double> sortedgenemap, String rankfile) throws IOException {
199
        
200
        BufferedWriter bw = new BufferedWriter(new FileWriter(new File(rankfile)));
201
        int rank = 0;
202
        bw.write("Rank" + "," + "Gene" + "," + "Score" + "," + "\n");
203
        for (String s : sortedgenemap.keySet()) {
204
//          System.out.println(sortedgenemap.get(s));
205
            Double score = sortedgenemap.get(s);
206
            rank++;
207
            bw.write(rank + "," + s + "," + score + "," + "\n");
208
        }
209
        bw.close();
210
    }
211
    
212
    /**
213
     * Write gene rank file into a CSV file, file format is "Rank, Gene"
214
     * @param generank gene rank HashMap
215
     * @param rankfile the file to be written
216
     * @throws IOException
217
     */
218
    public static void saveRankfile(Map<String, Integer> generank, String rankfile) throws IOException {
219
        
220
        BufferedWriter bw = new BufferedWriter(new FileWriter(new File(rankfile)));
221
        bw.write("Rank" + "," + "Gene" + ","  + "\n");
222
        for (String s : generank.keySet()) {
223
//          System.out.println(sortedgenemap.get(s));
224
            int rank = generank.get(s);
225
            bw.write(rank + "," + s + "," + "\n");
226
        }
227
        bw.close();
228
    }
229
    
230
    
231
    public static void main(String[] args) throws IOException {
232
        
233
        /*
234
         * Create additional DCN_OMIM map for give disease
235
         */
236
        String dcnmapfile = "./data/term_umls_id_diso";
237
        String omimmapfile = "./data/OMIM_umls_id_diso";
238
        util.DCNOMIMUMLSIDmap.createDCNIdNameMap(dcnmapfile);
239
        util.DCNOMIMUMLSIDmap.createOMIMIdNameMap(omimmapfile);
240
        
241
        String DCN_dis = "dementia";
242
        String OMIM_pat = "alzheimer";
243
        
244
        DCNOMIMUMLSIDmap domim = new DCNOMIMUMLSIDmap(DCN_dis, OMIM_pat);
245
        Map<String, List<String>> dcn_omim = domim.getDCNOMIMUMLSIDmap();
246
        
247
        /*
248
         * Build the bipartite network
249
         */
250
251
        String commnetfile = "./results/fares_comm_net_lift_final_abbr.txt";
252
        String ppifile = "./data/gene_gene_string_cut.txt";
253
        String disgenefile = "./data/OMIM_disease_gene_umls_id_diso";
254
        DisGraph cgg = null;
255
        try {
256
            cgg = CommGeneticsLoader.createCommGeneticsGraph(commnetfile, ppifile, disgenefile, dcn_omim);
257
        } catch (IOException e) {
258
            // TODO Auto-generated catch block
259
            e.printStackTrace();
260
        }
261
        
262
        /*
263
         * Create seed vector
264
         */
265
        int n = cgg.getNodes();
266
        
267
        List<String> seeds = new ArrayList<>();
268
        String seed_dis1 = "dementia";
269
        String seed_dis1_umls = null;
270
        if (util.DCNOMIMUMLSIDmap.dcnnameidmap.containsKey(seed_dis1)) {
271
            seed_dis1_umls = util.DCNOMIMUMLSIDmap.dcnnameidmap.get(seed_dis1);
272
            System.out.println(seed_dis1 + ": " + seed_dis1_umls);
273
            seeds.add(seed_dis1_umls);
274
        } else {
275
            System.out.println(seed_dis1 + " is NOT included in our network!");
276
        }
277
        
278
        String seed_genes_file = "./data/AD_omim_genes.txt";
279
        List<String> seed_genes = FileToList.createListFromfile(seed_genes_file);
280
        
281
        seeds.addAll(seed_genes);
282
//      diseases.add(disease2);
283
        SparseVector seeds_vec = new SparseVector(n);
284
        try {
285
            seeds_vec = createSeedVector(seeds);
286
        } catch (IOException e) {
287
            // TODO Auto-generated catch block
288
            e.printStackTrace();
289
        }
290
        System.out.println("Seed vector: " + seeds_vec);
291
        
292
        /*
293
         * Random walk
294
         */
295
296
        RandomWalk rwrank = new RandomWalk(cgg, seeds_vec, 0.5);
297
        Map<String, Integer> result = rwrank.getRWRank();
298
299
        /*
300
         * Process result and write to file
301
         */
302
303
        String rankfile = "./results/AD_novel_genes.csv";
304
        try {
305
            saveRankfile(result, rankfile);
306
        } catch (IOException e) {
307
            // TODO Auto-generated catch block
308
            e.printStackTrace();
309
        }
310
    }
311
}