Switch to unified view

a b/graph_algorithm/CrossValidation.java
1
package GraphAlgorithm;
2
3
import java.io.BufferedWriter;
4
import java.io.File;
5
import java.io.FileWriter;
6
import java.io.IOException;
7
import java.util.ArrayList;
8
import java.util.HashMap;
9
import java.util.List;
10
import java.util.Map;
11
import network.CommGeneticsLoader;
12
import network.DisGraph;
13
import network.SparseVector;
14
import util.DCNOMIMUMLSIDmap;
15
16
/**
17
 * Implements leave one out cross validation to test the performance of disease genetics prediction
18
 * @author zhengc
19
 *
20
 */
21
public class CrossValidation {
22
23
    private Map<String, Integer> generank = new HashMap<>();
24
    
25
    /**
26
     * Initiate and create a cross validation object 
27
     * @param cgg CommGeneticsGraph
28
     * @param dis query disease
29
     * @param genelist omim genes for query disease
30
     * @throws IOException
31
     */
32
    public CrossValidation(DisGraph cgg, String dis, List<String> genelist) throws IOException {
33
        String dis_id = null;
34
        if (DCNOMIMUMLSIDmap.dcnnameidmap.containsKey(dis)) {
35
            dis_id = DCNOMIMUMLSIDmap.dcnnameidmap.get(dis);
36
        } else {
37
            System.out.println("No such disease in this network!");
38
        }
39
        
40
        List<String> dislist = new ArrayList<>();
41
        dislist.add(dis);
42
        SparseVector seed = RandomWalk.createSeedVector(dislist);
43
        int dis_idx = CommGeneticsLoader.entry_index.get(dis_id);
44
        
45
        int count = 1;
46
        for (String gene : genelist) {
47
            System.out.println(count + " of " +  genelist.size());
48
            if (CommGeneticsLoader.entry_index.containsKey(gene)) {
49
                int gene_idx = CommGeneticsLoader.entry_index.get(gene);
50
                cgg.removeEdge(dis_idx, gene_idx);
51
                RandomWalk rw = new RandomWalk(cgg, seed, 0.2);
52
                Map<String, Integer> result = rw.getRWRank();
53
                int rank = result.get(gene);
54
                generank.put(gene, rank);
55
                count++;
56
            }
57
        }
58
    }
59
    
60
    /**
61
     * Get disease OMIM gene ranking from leave one out cross validation 
62
     * @return disease OMIM gene ranking map
63
     */
64
    public Map<String, Integer> getCV() {
65
        return generank;
66
    }
67
    
68
    /** 
69
     * Write leave one out cross validation to a file
70
     * @param cv disease OMIM gene ranking map
71
     * @param cvfile a file to be written
72
     * @throws IOException
73
     */
74
    public static void saveCV(Map<String, Integer> cv, String cvfile) throws IOException {
75
        BufferedWriter bw = new BufferedWriter(new FileWriter(new File(cvfile)));
76
        bw.write("Gene" + ","  + "Rank" + "," + "Percentage" +  "\n");
77
        for (String gene : cv.keySet()) {
78
//          System.out.println(sortedgenemap.get(s));
79
            int rank = cv.get(gene);
80
            double perc = rank / CommGeneticsLoader.num_gene * 100;
81
            bw.write(gene + "," + rank + "," + perc + "\n");
82
        }
83
        bw.close();
84
    }
85
    
86
    public static void main(String[] args) throws IOException {
87
        
88
        System.out.println("Building network...");
89
        
90
        String dcnmapfile = "/Users/zhengc/workspace/FARES/data/FARES/map/term_umls_id_diso";
91
        String omimmapfile = "/Users/zhengc/workspace/FARES/data/OMIM/map/OMIM_umls_id_diso";
92
        util.DCNOMIMUMLSIDmap.createDCNIdNameMap(dcnmapfile);
93
        util.DCNOMIMUMLSIDmap.createOMIMIdNameMap(omimmapfile);
94
        
95
        String DCN_dis = "dementia";
96
        String OMIM_pat = "alzheimer";
97
        
98
        DCNOMIMUMLSIDmap domim = new DCNOMIMUMLSIDmap(DCN_dis, OMIM_pat);
99
        Map<String, List<String>> dcn_omim = domim.getDCNOMIMUMLSIDmap();
100
        
101
        /*
102
         * Build the bipartite network
103
         */
104
        String commnetfile = "/Users/zhengc/Projects/AD_comorbidity/data/fares_comm_net_conf_ISMB_final_public.txt";
105
        String dmnfile = "/Users/zhengc/Projects/AD_comorbidity/data/dmn_dm.txt";
106
        String ppifile = "/Users/zhengc/Projects/AD_comorbidity/data/gene_gene_string_cut.txt";
107
        String disgenefile = "/Users/zhengc/workspace/FARES/data/OMIM/mapped_OMIM/OMIM_disease_gene_umls_id_diso";
108
        DisGraph cgg = null;
109
        try {
110
//          cgg = CommGeneticsLoader.createCommGeneticsGraph(commnetfile, dmnfile, ppifile, disgenefile, dcn_omim);
111
            cgg = CommGeneticsLoader.createCommGeneticsGraph(commnetfile, ppifile, disgenefile, dcn_omim);
112
        } catch (IOException e) {
113
            // TODO Auto-generated catch block
114
            e.printStackTrace();
115
        }
116
        
117
        System.out.println("Done!");
118
        
119
        /*
120
         * Cross validation
121
         */
122
        System.out.println("Starting crossvalidate...");
123
        
124
        String genelistfile = "/Users/zhengc/Projects/AD_comorbidity/data/AD_omim_genes.txt";
125
        List<String> genelist = util.FileToList.createListFromfile(genelistfile);
126
        
127
        CrossValidation cv = new CrossValidation(cgg, DCN_dis, genelist);
128
        Map<String, Integer> generank = cv.getCV();
129
        
130
        for (String gene : generank.keySet()) {
131
            System.out.println(gene + ": " + generank.get(gene));
132
        }
133
        
134
        String cv_result_file = "/Users/zhengc/Projects/AD_comorbidity/results/evaluation/cv_result.csv";
135
        saveCV(generank, cv_result_file);
136
137
    }
138
}